/**
 * 
 */
package edu.berkeley.nlp.PCFGLA;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.ScalingTools;

/**
 * Max-rule parser variant that keeps backpointers per (state, substate) pair.
 *
 * <p>The charts filled by {@link #doConstrainedMaxCScores} record, for every
 * span, state and substate, the best-scoring derivation found. The extraction
 * methods follow those substate-level backpointers to build the output tree.
 *
 * @author petrov
 */
public class CoarseToFineMaxRuleDerivationParser extends CoarseToFineMaxRuleParser {

  /** [start][end][state][substate] -> best log score found; NEGATIVE_INFINITY if none. */
  protected double[][][][] maxcScore;
  /** Not read or written in this class. */
  protected double[][][][] maxsScore;
  /** [start][end][state][substate] -> binary split point, or -1 if no binary backpointer. */
  protected int[][][][] maxcSplit;
  /** [start][end][state][substate] -> unary child state, or -1 if no unary was applied. */
  protected int[][][][] maxcChild;
  /** [start][end][state][substate] -> substate of the unary child, or -1. */
  protected int[][][][] maxcChildSub;
  /** [start][end][state][substate] -> left child state of the best binary rule, or -1. */
  protected int[][][][] maxcLeftChild;
  /** [start][end][state][substate] -> right child state of the best binary rule, or -1. */
  protected int[][][][] maxcRightChild;
  /** [start][end][state][substate] -> left child substate of the best binary rule, or -1. */
  protected int[][][][] maxcLeftChildSub;
  /** [start][end][state][substate] -> right child substate of the best binary rule, or -1. */
  protected int[][][][] maxcRightChildSub;

  /**
   * Creates the parser; every argument is forwarded unchanged to
   * {@link CoarseToFineMaxRuleParser}'s constructor.
   */
  public CoarseToFineMaxRuleDerivationParser(Grammar gr, Lexicon lex,
      double unaryPenalty, int endL, boolean viterbi, boolean sub,
      boolean score, boolean accurate, boolean variational, boolean useGoldPOS,
      boolean initializeCascade) {
    super(gr, lex, unaryPenalty, endL, viterbi, sub, score, accurate, variational,
        useGoldPOS, initializeCascade);
  }

  /**
   * Fills {@code maxcScore} and the backpointer charts bottom-up by span width.
   *
   * <p>Only states with {@code allowedStates[start][end][state]} set are
   * allocated or considered. For each span, either binary rules (width &gt; 1)
   * or lexicon scores (width 1) are applied first. Then a single layer of unary
   * rules is applied. The unary layer reads the pre-unary scores of the span
   * and writes into a copy that replaces the span's row at the end. A unary
   * result is therefore never used as the child of another unary in the same
   * span.
   *
   * @param sentence words of the sentence, indexed by position
   * @param grammar  grammar supplying substate counts and rules
   * @param lexicon  lexicon scoring each (word, tag) pair
   * @param scale    when true, adds the log of
   *                 {@code ScalingTools.calcScaleFactor} of the combined
   *                 oScale/iScale values to each score
   */
  void doConstrainedMaxCScores(List<String> sentence, Grammar grammar, Lexicon lexicon, final boolean scale) {
    numSubStatesArray = grammar.numSubStates;
    maxcScore = new double[length][length + 1][numStates][];
    maxcSplit = new int[length][length + 1][numStates][];
    maxcChild = new int[length][length + 1][numStates][];
    maxcChildSub = new int[length][length + 1][numStates][];
    maxcLeftChild = new int[length][length + 1][numStates][];
    maxcRightChild = new int[length][length + 1][numStates][];
    maxcLeftChildSub = new int[length][length + 1][numStates][];
    maxcRightChildSub = new int[length][length + 1][numStates][];

    // Sentinel meaning "no derivation found yet".
    final double initVal = Double.NEGATIVE_INFINITY;

    for (int start = 0; start < length; start++) {
      for (int end = start + 1; end <= length; end++) {
        for (int state = 0; state < numSubStatesArray.length; state++) {
          if (!allowedStates[start][end][state]) continue;
          int nSub = numSubStatesArray[state];
          maxcSplit[start][end][state] = new int[nSub];
          maxcChild[start][end][state] = new int[nSub];
          maxcChildSub[start][end][state] = new int[nSub];
          maxcLeftChild[start][end][state] = new int[nSub];
          maxcRightChild[start][end][state] = new int[nSub];
          maxcLeftChildSub[start][end][state] = new int[nSub];
          maxcRightChildSub[start][end][state] = new int[nSub];
          maxcScore[start][end][state] = new double[nSub];
          Arrays.fill(maxcSplit[start][end][state], -1);
          Arrays.fill(maxcChild[start][end][state], -1);
          Arrays.fill(maxcChildSub[start][end][state], -1);
          Arrays.fill(maxcLeftChild[start][end][state], -1);
          Arrays.fill(maxcRightChild[start][end][state], -1);
          Arrays.fill(maxcLeftChildSub[start][end][state], -1);
          Arrays.fill(maxcRightChildSub[start][end][state], -1);
          Arrays.fill(maxcScore[start][end][state], initVal);
        }
      }
    }

    // Inside score of state 0 / substate 0 over the whole sentence. Products of
    // inside and outside scores are divided by it (it is not a log value).
    double rootInsideScore = iScore[0][length][0][0];

    for (int diff = 1; diff <= length; diff++) {
      for (int start = 0; start < (length - diff + 1); start++) {
        int end = start + diff;
        if (diff > 1) {
          // Binary rules.
          for (int pState = 0; pState < numSubStatesArray.length; pState++) {
            if (!allowedStates[start][end][pState]) continue;
            BinaryRule[] parentRules = grammar.splitRulesWithP(pState);
            int nParentStates = numSubStatesArray[pState];

            for (BinaryRule r : parentRules) {
              int lState = r.leftChildState;
              int rState = r.rightChildState;

              // The extent tables bound the split points worth trying.
              int narrowR = narrowRExtent[start][lState];
              if (narrowR >= end) continue; // left child leaves no room for a right child
              int narrowL = narrowLExtent[end][rState];
              if (narrowL < narrowR) continue; // right child cannot sit next to the left child
              int min = Math.max(narrowR, wideLExtent[end][rState]);
              if (min > narrowL) continue;
              int max = Math.min(wideRExtent[start][lState], narrowL);
              if (min > max) continue;

              double[][][] scores = r.getScores2();
              int nLeftChildStates = numSubStatesArray[lState];
              int nRightChildStates = numSubStatesArray[rState];
              for (int split = min; split <= max; split++) {
                if (!allowedStates[start][split][lState]) continue;
                if (!allowedStates[split][end][rState]) continue;

                double scalingFactor = 0.0;
                if (scale) {
                  scalingFactor = Math.log(ScalingTools.calcScaleFactor(
                      oScale[start][end][pState] + iScale[start][split][lState]
                      + iScale[split][end][rState] - iScale[0][length][0]));
                }

                for (int lp = 0; lp < nLeftChildStates; lp++) {
                  double lIS = iScore[start][split][lState][lp];
                  if (lIS == 0) continue;

                  for (int rp = 0; rp < nRightChildStates; rp++) {
                    if (scores[lp][rp] == null) continue;
                    double rIS = iScore[split][end][rState][rp];
                    if (rIS == 0) continue;

                    double leftChildScore = maxcScore[start][split][lState][lp];
                    double rightChildScore = maxcScore[split][end][rState][rp];
                    if (leftChildScore == initVal || rightChildScore == initVal) continue;
                    double gScore = leftChildScore + scalingFactor + rightChildScore;

                    for (int np = 0; np < nParentStates; np++) {
                      double pOS = oScore[start][end][pState][np];
                      if (pOS == 0) continue;

                      double scoreToBeat = maxcScore[start][end][pState][np];
                      if (gScore < scoreToBeat) continue; // cannot produce a better derivation

                      double ruleS = scores[lp][rp][np];
                      if (ruleS == 0) continue;
                      double ruleScore = (pOS * ruleS * lIS * rIS) / rootInsideScore;
                      if (ruleScore == 0) continue;
                      if (doVariational) {
                        ruleScore /= pOS / rootInsideScore * iScore[start][end][pState][np];
                      }
                      ruleScore = gScore + Math.log(ruleScore);

                      if (ruleScore > scoreToBeat) {
                        maxcScore[start][end][pState][np] = ruleScore;
                        maxcSplit[start][end][pState][np] = split;
                        maxcLeftChild[start][end][pState][np] = lState;
                        maxcRightChild[start][end][pState][np] = rState;
                        maxcLeftChildSub[start][end][pState][np] = lp;
                        maxcRightChildSub[start][end][pState][np] = rp;
                      }
                    }
                  }
                }
              }
            }
          }
        } else {
          // Width-1 span: TAG -> word is scored like a unary rule, but with the
          // lexicon providing the rule score; a unary may still be applied on top.
          for (int tag = 0; tag < numSubStatesArray.length; tag++) {
            if (!allowedStates[start][end][tag]) continue;
            if (grammar.isGrammarTag(tag)) continue;
            int nTagStates = numSubStatesArray[tag];
            String word = sentence.get(start);
            double[] lexiconScoreArray = lexicon.score(word, (short) tag, start, false, false);
            for (int tp = 0; tp < nTagStates; tp++) {
              double pOS = oScore[start][end][tag][tp];
              double ruleS = lexiconScoreArray[tp];
              double lexiconScore = (pOS * ruleS) / rootInsideScore;
              double scalingFactor = 0.0;
              if (doVariational) {
                lexiconScore = 1;
              } else if (scale) {
                scalingFactor = Math.log(ScalingTools.calcScaleFactor(
                    oScale[start][end][tag] - iScale[0][length][0]));
              }
              maxcScore[start][end][tag][tp] = Math.log(lexiconScore) + scalingFactor;
            }
          }
        }

        // Unary rules: read the pre-unary scores of this span and write the
        // improvements into a copy that is installed once the layer is done.
        double[][] maxcScoreStartEnd = new double[numStates][];
        for (int i = 0; i < numStates; i++) {
          if (!allowedStates[start][end][i]) continue;
          maxcScoreStartEnd[i] = maxcScore[start][end][i].clone();
        }

        for (int pState = 0; pState < numSubStatesArray.length; pState++) {
          if (!allowedStates[start][end][pState]) continue;
          int nParentStates = numSubStatesArray[pState];
          UnaryRule[] unaries = doVariational
              ? grammar.getUnaryRulesByParent(pState).toArray(new UnaryRule[0])
              : grammar.getClosedSumUnaryRulesByParent(pState);
          for (UnaryRule ur : unaries) {
            int cState = ur.childState;
            if (pState == cState) continue;
            if (iScore[start][end][cState] == null) continue;
            double[] childRow = maxcScore[start][end][cState];
            if (childRow == null) continue; // child has no chart entry for this span

            double scalingFactor = 0.0;
            if (scale) {
              scalingFactor = Math.log(ScalingTools.calcScaleFactor(
                  oScale[start][end][pState] + iScale[start][end][cState]
                  - iScale[0][length][0]));
            }

            double[][] scores = ur.getScores2();
            int nChildStates = numSubStatesArray[cState];
            for (int cp = 0; cp < nChildStates; cp++) {
              double cIS = iScore[start][end][cState][cp];
              if (cIS == 0) continue;
              double childScore = childRow[cp];
              if (childScore == initVal) continue;
              if (scores[cp] == null) continue;
              double gScore = scalingFactor + childScore;

              for (int np = 0; np < nParentStates; np++) {
                double pOS = oScore[start][end][pState][np];
                // A zero outside score always yields a zero rule score, which is skipped below anyway.
                if (pOS <= 0) continue;
                if (gScore < maxcScoreStartEnd[pState][np]) continue;

                double ruleS = scores[cp][np];
                if (ruleS == 0) continue;
                double ruleScore = (pOS * ruleS * cIS) / rootInsideScore;
                if (ruleScore == 0) continue;
                if (doVariational) {
                  ruleScore /= pOS / rootInsideScore * iScore[start][end][pState][np];
                }
                ruleScore = gScore + Math.log(ruleScore);

                if (ruleScore > maxcScoreStartEnd[pState][np]) {
                  maxcScoreStartEnd[pState][np] = ruleScore;
                  maxcChild[start][end][pState][np] = cState;
                  maxcChildSub[start][end][pState][np] = cp;
                }
              }
            }
          }
        }
        maxcScore[start][end] = maxcScoreStartEnd;
      }
    }
  }

  /**
   * Extracts the best tree over [start, end) rooted at state 0, substate 0.
   */
  public Tree<String> extractBestMaxRuleParse(int start, int end, List<String> sentence) {
    return extractBestMaxRuleParse1(start, end, 0, 0, sentence);
  }

  /**
   * Returns the best parse for (state, substate) over [start, end), possibly
   * beginning with a unary rule.
   *
   * <p>If the grammar reports an intermediate node for the unary
   * (state -&gt; child), that node is re-inserted between parent and child.
   */
  public Tree<String> extractBestMaxRuleParse1(int start, int end, int state, int substate, List<String> sentence) {
    int cState = maxcChild[start][end][state][substate];
    if (cState == -1) {
      return extractBestMaxRuleParse2(start, end, state, substate, sentence);
    }
    int cSubState = maxcChildSub[start][end][state][substate];

    List<Tree<String>> child = new ArrayList<Tree<String>>();
    child.add(extractBestMaxRuleParse2(start, end, cState, cSubState, sentence));
    String stateStr = stripCaretGSuffix((String) tagNumberer.object(state));

    totalUsedUnaries++;
    int intermediateNode = grammar.getUnaryIntermediate((short) state, (short) cState);
    if (intermediateNode > 0) {
      nTimesRestoredUnaries++;
      String intermediateStr = stripCaretGSuffix((String) tagNumberer.object(intermediateNode));
      List<Tree<String>> restoredChild = new ArrayList<Tree<String>>();
      restoredChild.add(new Tree<String>(intermediateStr, child));
      return new Tree<String>(stateStr, restoredChild);
    }
    return new Tree<String>(stateStr, child);
  }

  /**
   * Returns the best parse for (state, substate) over [start, end); it never
   * begins with a unary rule.
   *
   * <p>If a multi-word span has no binary backpointer, a warning with the
   * chart contents is printed to stderr and a bare "ROOT" tree is returned.
   */
  public Tree<String> extractBestMaxRuleParse2(int start, int end, int state, int substate, List<String> sentence) {
    List<Tree<String>> children = new ArrayList<Tree<String>>();
    String stateStr = stripCaretGSuffix((String) tagNumberer.object(state));
    boolean posLevel = (end - start == 1);
    if (posLevel) {
      if (grammar.isGrammarTag(state)) {
        // NOTE(review): assumes maxcChild holds a valid unary child here; the
        // child label is used without stripping a "^g" suffix -- confirm intent.
        List<Tree<String>> wordLeaf = new ArrayList<Tree<String>>();
        wordLeaf.add(new Tree<String>(sentence.get(start)));
        String childStr = (String) tagNumberer.object(maxcChild[start][end][state][substate]);
        children.add(new Tree<String>(childStr, wordLeaf));
      } else {
        children.add(new Tree<String>(sentence.get(start)));
      }
    } else {
      int split = maxcSplit[start][end][state][substate];
      if (split == -1) {
        System.err.println("Warning: no symbol can generate the span from " + start + " to " + end + ".");
        System.err.println("The score is " + Arrays.toString(maxcScore[start][end][state]) + " and the state is supposed to be " + stateStr);
        System.err.println("The insideScores are " + Arrays.toString(iScore[start][end][state]) + " and the outsideScores are " + Arrays.toString(oScore[start][end][state]));
        System.err.println("The maxcScore is " + Arrays.toString(maxcScore[start][end][state]));
        return new Tree<String>("ROOT");
      }
      int lState = maxcLeftChild[start][end][state][substate];
      int lSubState = maxcLeftChildSub[start][end][state][substate];
      int rState = maxcRightChild[start][end][state][substate];
      int rSubState = maxcRightChildSub[start][end][state][substate];
      Tree<String> leftChildTree = extractBestMaxRuleParse1(start, split, lState, lSubState, sentence);
      Tree<String> rightChildTree = extractBestMaxRuleParse1(split, end, rState, rSubState, sentence);
      children.add(leftChildTree);
      children.add(rightChildTree);
    }
    return new Tree<String>(stateStr, children);
  }

  /** Returns {@code label} with a trailing "^g" removed, or unchanged if it has none. */
  private static String stripCaretGSuffix(String label) {
    return label.endsWith("^g") ? label.substring(0, label.length() - 2) : label;
  }
}
